package io.github.lee.jel.global.di.http

import java.net.InetAddress
import java.net.Socket
import javax.net.ssl.SSLSocket
import javax.net.ssl.SSLSocketFactory

/**
 * JSSE protocol names that [TLSSocketFactory] tries to enable on the sockets it creates.
 *
 * Despite the constant's name, the entries are TLS 1.1 and TLS 1.2. The strings are the
 * standard protocol identifiers accepted by `SSLSocket.setEnabledProtocols`.
 */
private val TLS_V1_V2: Array<String> = arrayOf(
    "TLSv1.1",
    "TLSv1.2"
)

/**
 * [SSLSocketFactory] decorator that adjusts the enabled TLS protocol versions on every
 * [SSLSocket] produced by [base].
 *
 * Every `createSocket` variant forwards to [base] and then runs the result through [patch].
 * [patch] enables the versions listed in `TLS_V1_V2` (TLSv1.1 and TLSv1.2) that the socket
 * reports as supported. If the socket already has TLSv1.3 enabled, that version is kept.
 * Cipher-suite queries are answered by [base] unchanged.
 *
 * Sockets that are not [SSLSocket] instances (and `null`) are returned untouched.
 *
 * @param base the factory that actually creates the sockets.
 */
class TLSSocketFactory(base: SSLSocketFactory) : SSLSocketFactory() {

    private val delegate: SSLSocketFactory = base

    override fun getDefaultCipherSuites(): Array<String> =
        delegate.defaultCipherSuites

    override fun getSupportedCipherSuites(): Array<String> =
        delegate.supportedCipherSuites

    /**
     * Creates an unconnected socket via [base] and patches it.
     *
     * Without this override, the inherited `SocketFactory.createSocket()` throws
     * `SocketException("Unconnected sockets not implemented")`. That happens even when [base]
     * supports unconnected sockets, so clients that connect lazily would fail.
     */
    override fun createSocket(): Socket? =
        patch(delegate.createSocket())

    override fun createSocket(s: Socket?, host: String?, port: Int, autoClose: Boolean): Socket? =
        patch(delegate.createSocket(s, host, port, autoClose))

    override fun createSocket(host: String?, port: Int): Socket? =
        patch(delegate.createSocket(host, port))

    override fun createSocket(
        host: String?,
        port: Int,
        localHost: InetAddress?,
        localPort: Int
    ): Socket? = patch(delegate.createSocket(host, port, localHost, localPort))

    override fun createSocket(host: InetAddress?, port: Int): Socket? =
        patch(delegate.createSocket(host, port))

    override fun createSocket(
        address: InetAddress?,
        port: Int,
        localAddress: InetAddress?,
        localPort: Int
    ): Socket? =
        patch(delegate.createSocket(address, port, localAddress, localPort))

    /**
     * Enables the desired TLS versions on [s] when it is an [SSLSocket]. Any other value,
     * including `null`, is returned as-is.
     *
     * Only names present in [SSLSocket.getSupportedProtocols] are requested, because
     * [SSLSocket.setEnabledProtocols] throws [IllegalArgumentException] for unsupported names.
     * An already-enabled TLSv1.3 is preserved so the wrapper does not downgrade connections.
     * If no desired version is available, the socket's enabled protocols are left unchanged.
     * Setting an empty list instead would make every handshake fail.
     *
     * @return the same socket instance that was passed in.
     */
    private fun patch(s: Socket?): Socket? {
        if (s !is SSLSocket) return s

        val supported = s.supportedProtocols.toHashSet()
        // LinkedHashSet keeps insertion order and removes duplicates.
        val wanted = LinkedHashSet<String>()
        TLS_V1_V2.filterTo(wanted) { it in supported }
        // Forcing only 1.1/1.2 would disable TLS 1.3 where the platform already enabled it.
        s.enabledProtocols.filterTo(wanted) { it == TLS_V1_3 }

        if (wanted.isNotEmpty()) {
            s.enabledProtocols = wanted.toTypedArray()
        }
        return s
    }

    private companion object {
        /** JSSE name of TLS 1.3; kept enabled when the platform already turned it on. */
        const val TLS_V1_3 = "TLSv1.3"
    }
}

